import argparse
import torch
import os
import re
import json
from tqdm import tqdm
import shortuuid
import requests
from io import BytesIO
import torch.multiprocessing as mp
from torch.multiprocessing import Process, Manager


from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

from tci_attn import LlamaAttentionWithLogits
from PIL import Image
import math

from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)


def load_image(image_file: str, timeout: float = 30.0) -> Image.Image:
    """Load an image from a local path or an HTTP(S) URL and return it in RGB mode.

    Args:
        image_file: Filesystem path, or a URL starting with ``http://`` or ``https://``.
        timeout: Seconds to wait for the HTTP server (only used for URLs), so a stalled
            download cannot hang the caller indefinitely.

    Returns:
        A fully loaded ``PIL.Image.Image`` in RGB mode that does not hold any open file.

    Raises:
        requests.HTTPError: If the URL request returns an error status.
        requests.Timeout: If the server does not respond within ``timeout`` seconds.
        Errors raised by ``Image.open`` (missing or undecodable file) propagate unchanged.
    """
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    # Image.open is lazy and keeps the underlying file open; convert() produces an
    # independent, fully loaded copy, so the handle can be closed right away.
    with Image.open(source) as img:
        return img.convert("RGB")

def load_images(image_files: list[str]) -> list[Image.Image]:
    """Load every path/URL in ``image_files`` via ``load_image``, preserving input order."""
    return [load_image(path) for path in image_files]

def image_parser(image_arg: str, sep: str) -> list[str]:
    """Split a ``sep``-delimited string of image paths/URLs into a list of its parts.

    Empty fields are kept, e.g. ``"a||b"`` split on ``"|"`` gives ``["a", "", "b"]``.
    """
    pieces = image_arg.split(sep)
    return pieces

def _build_prompt(question, model_config, conv_mode):
    """Return the ``conv_mode`` conversation prompt for ``question`` with the image token inserted.

    Every ``IMAGE_PLACEHOLDER`` in ``question`` is replaced by the image token; if there is
    none, the image token is prepended on its own line. When
    ``model_config.mm_use_im_start_end`` is truthy, the image token is wrapped in the
    image start/end tokens.
    """
    if model_config.mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN

    if IMAGE_PLACEHOLDER in question:
        question = re.sub(IMAGE_PLACEHOLDER, image_token, question)
    else:
        question = image_token + "\n" + question

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], question)
    # Presumably a None message leaves the assistant turn open for generation.
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


def eval_model(rank, args_dict, shared_results):
    """Worker entry point: run LLaVA generation on this rank's shard of ``args_dict['data']``.

    Args:
        rank: Worker index. Selects the GPU ``args_dict['gpu_ids'][rank]`` and the shard
            ``args_dict['data'][rank::len(args_dict['gpu_ids'])]``.
        args_dict: Run configuration. Required keys: ``gpu_ids``, ``model_path``,
            ``model_base``, ``data``, ``image_folder``, ``conv_mode``, ``temperature``,
            ``top_p``, ``top_k``, ``max_new_tokens``. Optional keys: ``torch_dtype``
            (default ``torch.float16``), ``tci`` (default off), ``tci_layers`` (default
            ``(0, 1, 14, 15, 17)``); ``alpha`` is required only when ``tci`` is truthy.
        shared_results: List-like object with ``extend`` (``process_json_file`` passes a
            ``Manager().list()``). Receives one dict per sample with keys ``question_id``,
            ``image``, ``text``, ``label`` and ``model_answer``.

    Each record in ``args_dict['data']`` must provide ``image`` (path relative to
    ``image_folder``), ``text`` (the question), ``question_id`` and ``label``.
    """
    # Restrict this process to one GPU. NOTE: CUDA reads CUDA_VISIBLE_DEVICES when it
    # initialises, so this must run before any CUDA call in this (spawned) process; the
    # selected GPU is then addressed as cuda:0.
    visible_device = args_dict['gpu_ids'][rank]
    os.environ["CUDA_VISIBLE_DEVICES"] = str(visible_device)
    device = torch.device("cuda:0")

    disable_torch_init()

    # 1. Load the model, honouring the configured dtype (float16 unless overridden).
    model_name = get_model_name_from_path(args_dict['model_path'])
    tokenizer, model, image_processor, _context_len = load_pretrained_model(
        model_path=args_dict['model_path'],
        model_base=args_dict['model_base'],
        model_name=model_name,
        torch_dtype=args_dict.get('torch_dtype', torch.float16),
    )
    model.to(device)
    # Cast inputs and replacement modules to the dtype the loaded model actually uses, so
    # they always agree with it even if the loader chooses a different dtype than requested.
    dtype = model.dtype

    # 2. Optionally replace the attention of selected decoder layers with
    #    LlamaAttentionWithLogits, reusing the original attention weights.
    if args_dict.get('tci'):
        tci_layers = set(args_dict.get('tci_layers', (0, 1, 14, 15, 17)))
        for i, layer in enumerate(model.model.layers):
            if i in tci_layers:
                attn_adap = LlamaAttentionWithLogits(
                    layer.self_attn.config, layer_idx=i, alpha=args_dict['alpha']
                )
                attn_adap.load_state_dict(layer.self_attn.state_dict())
                layer.self_attn = attn_adap.to(device=device, dtype=dtype)

    # 3. Generate one answer per sample in this rank's strided shard.
    total_gpus = len(args_dict['gpu_ids'])
    data_shard = args_dict['data'][rank::total_gpus]
    conv_mode = args_dict['conv_mode']
    temperature = args_dict['temperature']

    results = []
    for data in tqdm(data_shard, desc=f"GPU {visible_device} Processing"):
        image_file = os.path.join(args_dict['image_folder'], data["image"])
        image = load_image(image_file)
        images_tensor = process_images([image], image_processor, model.config).to(
            device, dtype=dtype
        )

        prompt = _build_prompt(data['text'], model.config, conv_mode)
        input_ids = (
            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .to(device)
        )

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=images_tensor,
                image_sizes=[image.size],
                do_sample=temperature > 0,  # sampling disabled when temperature is 0
                temperature=temperature,
                top_p=args_dict['top_p'],
                top_k=args_dict['top_k'],
                max_new_tokens=args_dict['max_new_tokens'],
                use_cache=True,
            )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        results.append({
            "question_id": data["question_id"],
            "image": data["image"],
            "text": data["text"],
            "label": data["label"],
            "model_answer": outputs,
        })

    # Publish once at the end: a single call on the shared list rather than one per sample.
    shared_results.extend(results)
    print(f"GPU {visible_device} finished processing {len(results)} samples")

def _load_records(json_file_path):
    """Return the records in ``json_file_path``, which may be a JSON document or JSON Lines.

    A top-level JSON array is returned as-is and any other single JSON value becomes a
    one-element list. Files that are not one JSON document are parsed as JSON Lines,
    skipping blank lines.
    """
    with open(json_file_path, "r", encoding="utf-8") as f:
        text = f.read()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return [json.loads(line) for line in text.splitlines() if line.strip()]
    return parsed if isinstance(parsed, list) else [parsed]


def process_json_file(json_file_path, args_base, output_folder):
    """Run multi-GPU inference over one POPE question file and write the results as JSON.

    The split name ("feature") is taken from a ``coco_pope_<feature>.json`` file name
    (``"unknown"`` if it does not match). One ``eval_model`` worker process is spawned per
    entry of ``args_base['gpu_ids']``; their results are gathered and written to
    ``<output_folder>/res_pope_<feature>_W+<alpha>_V2.json``.

    Args:
        json_file_path: Path to a JSON-array or JSON-Lines file of question records.
        args_base: Base run configuration (see ``eval_model``); copied, not modified.
        output_folder: Existing directory to write the results file into.

    Returns:
        The path of the written results file.
    """
    file_name = os.path.basename(json_file_path)
    match = re.search(r'coco_pope_(.+)\.json', file_name)
    if match:
        feature = match.group(1)
    else:
        feature = "unknown"
        print(f"WARNING cant get feature form {file_name} , will use 'unknown'")

    data_list = _load_records(json_file_path)

    args_dict = args_base.copy()
    args_dict['data'] = data_list

    # CUDA does not support the 'fork' start method for subprocesses; workers must be spawned.
    mp.set_start_method('spawn', force=True)
    with Manager() as manager:
        shared_results = manager.list()

        processes = []
        for rank in range(len(args_dict['gpu_ids'])):
            p = Process(target=eval_model, args=(rank, args_dict, shared_results))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # A crashed worker (e.g. OOM) loses its whole shard; report it instead of
        # silently writing a partial result set.
        failed_gpus = [
            args_dict['gpu_ids'][rank]
            for rank, p in enumerate(processes)
            if p.exitcode != 0
        ]
        final_results = list(shared_results)

    if failed_gpus:
        print(
            f"WARNING: worker(s) on GPU(s) {failed_gpus} exited abnormally; "
            f"collected {len(final_results)}/{len(data_list)} samples"
        )

    output_file = os.path.join(output_folder, f"res_pope_{feature}_W+{args_dict['alpha']}_V2.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(final_results, f, indent=4)

    print(f"Done: {file_name} -> {os.path.basename(output_file)}, total {len(final_results)} samples")
    return output_file

def main():
    """Run POPE inference for every ``coco_pope_*.json`` file in the input folder.

    Paths, GPUs and generation settings are fixed below. Files are processed in sorted
    order, and each produces one results file in ``output_folder`` (via
    ``process_json_file``). Prints a warning and returns early if no input file is found.
    """
    model_path = "llava-v1.5-7b"
    image_folder = "COCO_val2014"
    alpha = 4
    # The folder name encodes alpha; derive it so the two cannot drift apart.
    output_folder = f"results/pope/POPE_W+_V2/W+{alpha}_V2"
    input_folder = "tci/POPE"

    gpu_ids = [0, 1, 2, 3, 4, 5, 6, 7]

    os.makedirs(output_folder, exist_ok=True)

    # os.listdir order is arbitrary; sort so runs process the splits in a stable order.
    json_files = sorted(
        os.path.join(input_folder, file)
        for file in os.listdir(input_folder)
        if file.startswith("coco_pope_") and file.endswith(".json")
    )

    if not json_files:
        print(f"WARNING: can not  find josn file in {input_folder} (coco_pope_*.json)")
        return

    print(f"find {len(json_files)} JSON file to process")

    args_base = {
        'model_path': model_path,
        'model_base': None,
        'conv_mode': "vicuna_v1",
        'image_folder': image_folder,
        'output_folder': output_folder,
        'sep': ",",
        'temperature': 0.0,  # > 0 enables sampling
        'top_p': 1,
        'top_k': None,
        'num_beams': 1,
        'max_new_tokens': 512,
        'torch_dtype': torch.float16,
        'tci': True,
        'alpha': alpha,
        'gpu_ids': gpu_ids,
    }

    output_files = []
    for json_file in json_files:
        output_file = process_json_file(json_file, args_base, output_folder)
        output_files.append(output_file)

    print("\n===== ALL DONE =====")
    for input_file, output_file in zip(json_files, output_files):
        print(f"input: {os.path.basename(input_file)} -> output: {os.path.basename(output_file)}")
    
if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())